{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cropped Manual Training Loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Here, we show cropped decoding for the case where you want to write your own training loop. For simpler code with a predefined training loop and a general explanation of cropped decoding, see the [Cropped Decoding Tutorial](./Cropped_Decoding.html).\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most of the code for cropped decoding is identical to the [Trialwise Manual Training Loop Tutorial](./Trialwise_Manual_Training_Loop.html); the differences are explained in the text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "from mne.io import concatenate_raws\n",
    "\n",
    "# 5,6,9,10,13,14 are codes for executed and imagined hands/feet\n",
    "subject_id = 22 # carefully cherry-picked to give nice results on such limited data :)\n",
    "event_codes = [5,6,9,10,13,14]\n",
    "\n",
    "# This will download the files if you don't have them yet,\n",
    "# and then return the paths to the files.\n",
    "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n",
    "\n",
    "# Load each of the files\n",
    "parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto', verbose='WARNING')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "# Concatenate them\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "# Find the events in this dataset\n",
    "events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')\n",
    "\n",
    "# Use only EEG channels\n",
    "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "# Extract trials, only using EEG channels\n",
    "epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n",
    "                baseline=None, preload=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert data to Braindecode format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "# Convert data from volt to microvolt\n",
    "# Pytorch expects float32 for input and int64 for labels.\n",
    "X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1\n",
    "\n",
    "train_set = SignalAndTarget(X[:40], y=y[:40])\n",
    "valid_set = SignalAndTarget(X[40:70], y=y[40:70])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For cropped decoding, we now transform the model into one that outputs a dense time series of predictions.\n",
    "To do this, we manually set the length of the final convolution layer so that the receptive field of the ConvNet is smaller than the number of samples in a trial. We also call `to_dense_prediction_model`, which removes the strides in the ConvNet and uses dilated convolutions instead to produce a dense output (for background, see [Multi-Scale Context Aggregation by Dilated Convolutions](https://arxiv.org/abs/1511.07122) and Section 2.5.4 of our paper [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051))."
   ]
  },
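  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is a minimal NumPy sketch of the idea behind dense prediction, not Braindecode's implementation. It uses a toy two-layer 1-D network in which the stride of the first layer is removed and the second layer is dilated by that stride instead. The helper `conv1d`, the random kernels, and all sizes are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def conv1d(x, kernel, stride=1, dilation=1):\n",
    "    # Valid 1-D cross-correlation with optional stride and dilation\n",
    "    span = (len(kernel) - 1) * dilation + 1  # input samples covered by the kernel\n",
    "    starts = range(0, len(x) - span + 1, stride)\n",
    "    return np.array([np.dot(x[s:s + span:dilation], kernel) for s in starts])\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "x = rng.randn(50)\n",
    "k1 = rng.randn(5)\n",
    "k2 = rng.randn(3)\n",
    "\n",
    "# Strided network: stride 2 in the first layer, ordinary second layer\n",
    "strided_out = conv1d(conv1d(x, k1, stride=2), k2)\n",
    "\n",
    "# Dense network: stride removed, second layer dilated by the removed stride\n",
    "dense_out = conv1d(conv1d(x, k1, stride=1), k2, dilation=2)\n",
    "\n",
    "print(len(strided_out), len(dense_out))\n",
    "# Every second dense output reproduces one output of the strided network\n",
    "assert np.allclose(dense_out[::2], strided_out)\n",
    "```\n",
    "\n",
    "Every second dense output equals one output of the strided network. The outputs in between are the predictions the strided network would have produced for an input shifted by one sample."
   ]
  },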
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "from braindecode.models.util import to_dense_prediction_model\n",
    "\n",
    "# Set if you want to use GPU\n",
    "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n",
    "cuda = False\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "\n",
    "# This will determine how many crops are processed in parallel\n",
    "input_time_length = 450\n",
    "n_classes = 2\n",
    "in_chans = train_set.X.shape[1]\n",
    "# final_conv_length determines the size of the receptive field of the ConvNet\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n",
    "                        final_conv_length=12).create_network()\n",
    "to_dense_prediction_model(model)\n",
    "\n",
    "if cuda:\n",
    "    model.cuda()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create cropped iterator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To extract crops from the trials, Braindecode provides the `CropsFromTrialsIterator` class. This class needs to know the input time length of the inputs you feed into the network and the number of predictions that the ConvNet outputs per input. You can determine the number of predictions by passing dummy data through the ConvNet:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "187 predictions per input/trial\n"
     ]
    }
   ],
   "source": [
    "from braindecode.torch_ext.util import np_to_var\n",
    "# determine output size\n",
    "test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))\n",
    "if cuda:\n",
    "    test_input = test_input.cuda()\n",
    "out = model(test_input)\n",
    "n_preds_per_input = out.cpu().data.numpy().shape[2]\n",
    "print(\"{:d} predictions per input/trial\".format(n_preds_per_input))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.iterators import CropsFromTrialsIterator\n",
    "iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,\n",
    "                                  n_preds_per_input=n_preds_per_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The iterator has the method `get_batches`, which returns randomly shuffled training batches with `shuffle=True` or ordered batches (i.e. first from trial 1, then from trial 2, etc.) with `shuffle=False`. Additionally, Braindecode provides the `compute_preds_per_trial_from_crops` function, which accepts predictions from the ordered batches and returns predictions per trial. It removes any overlapping predictions, which occur if the number of predictions per input is not a divisor of the number of samples in a trial.\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "These methods can also work with trials of different lengths! For different-length trials, set `X` to be a list of 2d-arrays instead of a 3d-array.\n",
    "\n",
    "</div>"
   ]
  },
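  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the overlap concrete, here is a sketch of one plausible way to place crops within a single trial. The helper `crop_starts` is hypothetical and is not Braindecode's code; the trial length of 497 samples is an assumed value for illustration, while 450 and 187 are the values used in this tutorial. The sketch also assumes that each crop yields predictions for its last `n_preds_per_input` samples.\n",
    "\n",
    "```python\n",
    "def crop_starts(n_samples, input_time_length, n_preds_per_input):\n",
    "    # Hypothetical helper: start index of each crop within one trial\n",
    "    if n_samples < input_time_length:\n",
    "        raise ValueError('trial is shorter than the network input')\n",
    "    last_start = n_samples - input_time_length\n",
    "    # Step by the number of predictions so that predictions tile the trial\n",
    "    starts = list(range(0, last_start + 1, n_preds_per_input))\n",
    "    # Add a final crop flush with the trial end if samples are left over\n",
    "    if starts[-1] != last_start:\n",
    "        starts.append(last_start)\n",
    "    return starts\n",
    "\n",
    "n_samples = 497  # assumed trial length, for illustration only\n",
    "input_time_length = 450\n",
    "n_preds_per_input = 187\n",
    "\n",
    "starts = crop_starts(n_samples, input_time_length, n_preds_per_input)\n",
    "# Assumption: each crop predicts for its last n_preds_per_input samples\n",
    "pred_ranges = [(s + input_time_length - n_preds_per_input, s + input_time_length)\n",
    "               for s in starts]\n",
    "n_total = len(starts) * n_preds_per_input\n",
    "n_unique = len(set(i for a, b in pred_ranges for i in range(a, b)))\n",
    "print(starts, pred_ranges)\n",
    "print(n_total - n_unique, 'overlapping predictions to remove')\n",
    "```\n",
    "\n",
    "Under these assumptions, two crops cover the trial, and the second crop repeats some predictions of the first. These repeated predictions are the kind of overlap that has to be dropped before averaging per trial."
   ]
  },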
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now set up the optimizer, since the iterator lets us compute the number of batches per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.torch_ext.optimizers import AdamW\n",
    "from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing\n",
    "from braindecode.datautil.iterators import get_balanced_batches\n",
    "from numpy.random import RandomState\n",
    "rng = RandomState((2018,8,7))\n",
    "#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model\n",
    "optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)\n",
    "# Need to determine number of batch passes per epoch for cosine annealing\n",
    "n_epochs = 30\n",
    "n_updates_per_epoch = len([None for b in iterator.get_batches(train_set, True)])\n",
    "scheduler = CosineAnnealing(n_epochs * n_updates_per_epoch)\n",
    "# schedule_weight_decay must be True for AdamW\n",
    "optimizer = ScheduledOptimizer(scheduler, optimizer, schedule_weight_decay=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code below uses both the cropped iterator and the `compute_preds_per_trial_from_crops` function to train and evaluate the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train  Loss: 3.79010\n",
      "Train  Accuracy: 50.0%\n",
      "Valid  Loss: 3.12765\n",
      "Valid  Accuracy: 46.7%\n",
      "Epoch 1\n",
      "Train  Loss: 1.91778\n",
      "Train  Accuracy: 50.0%\n",
      "Valid  Loss: 1.52058\n",
      "Valid  Accuracy: 46.7%\n",
      "Epoch 2\n",
      "Train  Loss: 1.09943\n",
      "Train  Accuracy: 60.0%\n",
      "Valid  Loss: 0.99602\n",
      "Valid  Accuracy: 56.7%\n",
      "Epoch 3\n",
      "Train  Loss: 0.73015\n",
      "Train  Accuracy: 67.5%\n",
      "Valid  Loss: 0.82321\n",
      "Valid  Accuracy: 56.7%\n",
      "Epoch 4\n",
      "Train  Loss: 0.55965\n",
      "Train  Accuracy: 75.0%\n",
      "Valid  Loss: 0.77549\n",
      "Valid  Accuracy: 63.3%\n",
      "Epoch 5\n",
      "Train  Loss: 0.38075\n",
      "Train  Accuracy: 82.5%\n",
      "Valid  Loss: 0.65490\n",
      "Valid  Accuracy: 73.3%\n",
      "Epoch 6\n",
      "Train  Loss: 0.28886\n",
      "Train  Accuracy: 90.0%\n",
      "Valid  Loss: 0.59481\n",
      "Valid  Accuracy: 73.3%\n",
      "Epoch 7\n",
      "Train  Loss: 0.21871\n",
      "Train  Accuracy: 97.5%\n",
      "Valid  Loss: 0.52357\n",
      "Valid  Accuracy: 83.3%\n",
      "Epoch 8\n",
      "Train  Loss: 0.16713\n",
      "Train  Accuracy: 97.5%\n",
      "Valid  Loss: 0.44890\n",
      "Valid  Accuracy: 86.7%\n",
      "Epoch 9\n",
      "Train  Loss: 0.13149\n",
      "Train  Accuracy: 97.5%\n",
      "Valid  Loss: 0.40492\n",
      "Valid  Accuracy: 83.3%\n",
      "Epoch 10\n",
      "Train  Loss: 0.09581\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.36021\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 11\n",
      "Train  Loss: 0.07818\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.34625\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 12\n",
      "Train  Loss: 0.07454\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.34489\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 13\n",
      "Train  Loss: 0.06694\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.33878\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 14\n",
      "Train  Loss: 0.05971\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.33128\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 15\n",
      "Train  Loss: 0.05269\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.32202\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 16\n",
      "Train  Loss: 0.04354\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.31063\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 17\n",
      "Train  Loss: 0.03759\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.30314\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 18\n",
      "Train  Loss: 0.03401\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.29997\n",
      "Valid  Accuracy: 90.0%\n",
      "Epoch 19\n",
      "Train  Loss: 0.03145\n",
      "Train  Accuracy: 100.0%\n",
      "Valid  Loss: 0.29764\n",
      "Valid  Accuracy: 90.0%\n"
     ]
    }
   ],
   "source": [
    "from braindecode.torch_ext.util import np_to_var, var_to_np\n",
    "import torch.nn.functional as F\n",
    "from numpy.random import RandomState\n",
    "import torch as th\n",
    "from braindecode.experiments.monitors import compute_preds_per_trial_from_crops\n",
    "rng = RandomState((2017,6,30))\n",
    "for i_epoch in range(20):\n",
    "    # Set model to training mode\n",
    "    model.train()\n",
    "    for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):\n",
    "        net_in = np_to_var(batch_X)\n",
    "        if cuda:\n",
    "            net_in = net_in.cuda()\n",
    "        net_target = np_to_var(batch_y)\n",
    "        if cuda:\n",
    "            net_target = net_target.cuda()\n",
    "        # Remove gradients of last backward pass from all parameters \n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(net_in)\n",
    "        # Mean predictions across trial\n",
    "        # Note that this will give identical gradients to computing\n",
    "        # a per-prediction loss (at least for the combination of log softmax activation \n",
    "        # and negative log likelihood loss which we are using here)\n",
    "        outputs = th.mean(outputs, dim=2, keepdim=False)\n",
    "        loss = F.nll_loss(outputs, net_target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    # Print some statistics each epoch\n",
    "    model.eval()\n",
    "    print(\"Epoch {:d}\".format(i_epoch))\n",
    "    for setname, dataset in (('Train', train_set),('Valid', valid_set)):\n",
    "        # Collect all predictions and losses\n",
    "        all_preds = []\n",
    "        all_losses = []\n",
    "        batch_sizes = []\n",
    "        for batch_X, batch_y in iterator.get_batches(dataset, shuffle=False):\n",
    "            net_in = np_to_var(batch_X)\n",
    "            if cuda:\n",
    "                net_in = net_in.cuda()\n",
    "            net_target = np_to_var(batch_y)\n",
    "            if cuda:\n",
    "                net_target = net_target.cuda()\n",
    "            outputs = model(net_in)\n",
    "            all_preds.append(var_to_np(outputs))\n",
    "            outputs = th.mean(outputs, dim=2, keepdim=False)\n",
    "            loss = F.nll_loss(outputs, net_target)\n",
    "            loss = float(var_to_np(loss))\n",
    "            all_losses.append(loss)\n",
    "            batch_sizes.append(len(batch_X))\n",
    "        # Compute mean per-input loss \n",
    "        loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /\n",
    "                       np.mean(batch_sizes))\n",
    "        print(\"{:6s} Loss: {:.5f}\".format(setname, loss))\n",
    "        # Assign the predictions to the trials\n",
    "        preds_per_trial = compute_preds_per_trial_from_crops(all_preds,\n",
    "                                                          input_time_length,\n",
    "                                                          dataset.X)\n",
    "        # preds per trial are now trials x classes x timesteps/predictions\n",
    "        # Now mean across timesteps for each trial to get per-trial predictions\n",
    "        meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])\n",
    "        predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)\n",
    "        accuracy = np.mean(predicted_labels == dataset.y)\n",
    "        print(\"{:6s} Accuracy: {:.1f}%\".format(\n",
    "            setname, accuracy * 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eventually, we arrive at 90.0% accuracy, so 27 of 30 trials are correctly predicted, 5 more than for the trialwise decoding method."
   ]
  },
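  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop above averages the log-softmax outputs over time before computing the negative log likelihood loss, and its comments note that this matches a per-prediction loss. The NumPy sketch below checks that equivalence for the loss value on random data. All names and sizes are made up for illustration, and the loss is computed by hand rather than with PyTorch.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "n_trials, n_classes, n_preds = 4, 2, 187\n",
    "logits = rng.randn(n_trials, n_classes, n_preds)\n",
    "# Log-softmax over the class axis\n",
    "log_probs = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))\n",
    "targets = np.array([0, 1, 1, 0])\n",
    "trial_inds = np.arange(n_trials)\n",
    "\n",
    "# Loss on time-averaged log-probabilities, as in the training loop\n",
    "mean_log_probs = np.mean(log_probs, axis=2)\n",
    "loss_of_mean = -np.mean(mean_log_probs[trial_inds, targets])\n",
    "\n",
    "# Mean of the losses of all individual predictions\n",
    "per_pred_losses = -log_probs[trial_inds, targets, :]\n",
    "mean_of_losses = np.mean(per_pred_losses)\n",
    "\n",
    "print(loss_of_mean, mean_of_losses)\n",
    "assert np.isclose(loss_of_mean, mean_of_losses)\n",
    "```\n",
    "\n",
    "The two values agree because the negative log likelihood is linear in the log-probabilities, so averaging before or after taking the loss makes no difference when every input has the same number of predictions."
   ]
  },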
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once all hyperparameters and architectural choices are fixed, we can compute the accuracies to report in our publication by evaluating on the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 0.42105\n",
      "Test Accuracy: 85.0%\n"
     ]
    }
   ],
   "source": [
    "test_set = SignalAndTarget(X[70:], y=y[70:])\n",
    "\n",
    "model.eval()\n",
    "# Collect all predictions and losses\n",
    "all_preds = []\n",
    "all_losses = []\n",
    "batch_sizes = []\n",
    "for batch_X, batch_y in iterator.get_batches(test_set, shuffle=False):\n",
    "    net_in = np_to_var(batch_X)\n",
    "    if cuda:\n",
    "        net_in = net_in.cuda()\n",
    "    net_target = np_to_var(batch_y)\n",
    "    if cuda:\n",
    "        net_target = net_target.cuda()\n",
    "    outputs = model(net_in)\n",
    "    all_preds.append(var_to_np(outputs))\n",
    "    outputs = th.mean(outputs, dim=2, keepdim=False)\n",
    "    loss = F.nll_loss(outputs, net_target)\n",
    "    loss = float(var_to_np(loss))\n",
    "    all_losses.append(loss)\n",
    "    batch_sizes.append(len(batch_X))\n",
    "# Compute mean per-input loss \n",
    "loss = np.mean(np.array(all_losses) * np.array(batch_sizes) /\n",
    "               np.mean(batch_sizes))\n",
    "print(\"Test Loss: {:.5f}\".format(loss))\n",
    "# Assign the predictions to the trials\n",
    "preds_per_trial = compute_preds_per_trial_from_crops(all_preds,\n",
    "                                                  input_time_length,\n",
    "                                                  test_set.X)\n",
    "# preds per trial are now trials x classes x timesteps/predictions\n",
    "# Now mean across timesteps for each trial to get per-trial predictions\n",
    "meaned_preds_per_trial = np.array([np.mean(p, axis=1) for p in preds_per_trial])\n",
    "predicted_labels = np.argmax(meaned_preds_per_trial, axis=1)\n",
    "accuracy = np.mean(predicted_labels == test_set.y)\n",
    "print(\"Test Accuracy: {:.1f}%\".format(accuracy * 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset references\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n",
    "\n",
    "    Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n",
    "\n",
    "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:\n",
    "\n",
    "    Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220."
   ]
  }
 ],
 "metadata": {
  "git": {
   "keep_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
